{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "\n",
    "# Flattened image width (28 * 28); input size of D and output size of G.\n",
    "image_size = 784\n",
    "# NOTE(review): h_dim and z_dim are not referenced by any later cell in this notebook.\n",
    "h_dim = 400\n",
    "z_dim = 20\n",
    "# Number of passes over the training set in the training loop.\n",
    "num_epochs = 30\n",
    "# Mini-batch size handed to the DataLoader.\n",
    "batch_size = 128\n",
    "# NOTE(review): not referenced later; the optimizers are built with lr=0.0003.\n",
    "learning_rate = 0.001\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale pixels from [0, 1] to [-1, 1] so real images share the value range of\n",
    "# the generator's Tanh output (the training loop maps both back with denorm).\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(0.5,), std=(0.5,)),\n",
    "])\n",
    "\n",
    "# download=True only fetches MNIST when it is not already present under `root`,\n",
    "# so the notebook also runs on a fresh checkout.\n",
    "dataset = torchvision.datasets.MNIST(root='data',\n",
    "                                     train=True,\n",
    "                                     transform=transform,\n",
    "                                     download=True)\n",
    "\n",
    "# Shuffled mini-batch loader over the training set.\n",
    "data_loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                          batch_size=batch_size,\n",
    "                                          shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden_size = 1200\n",
    "\n",
    "# Discriminator: flattened image -> single sigmoid score in (0, 1).\n",
    "d_layers = [\n",
    "    nn.Linear(image_size, hidden_size),\n",
    "    nn.LeakyReLU(0.2),\n",
    "    nn.Linear(hidden_size, hidden_size),\n",
    "    nn.LeakyReLU(0.2),\n",
    "    nn.Linear(hidden_size, 1),\n",
    "    nn.Sigmoid(),\n",
    "]\n",
    "D = nn.Sequential(*d_layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "latent_size = 784\n",
    "\n",
    "# Generator: latent vector -> flattened image squashed into [-1, 1] by Tanh.\n",
    "# It plays the role that the decoder plays in a VAE.\n",
    "g_layers = [\n",
    "    nn.Linear(latent_size, hidden_size),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_size, hidden_size),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_size, image_size),\n",
    "    nn.Tanh(),\n",
    "]\n",
    "G = nn.Sequential(*g_layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Linear(in_features=784, out_features=1200, bias=True)\n",
       "  (1): ReLU()\n",
       "  (2): Linear(in_features=1200, out_features=1200, bias=True)\n",
       "  (3): ReLU()\n",
       "  (4): Linear(in_features=1200, out_features=784, bias=True)\n",
       "  (5): Tanh()\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use the first GPU when one is available; otherwise fall back to the CPU so\n",
    "# the notebook still runs on machines without CUDA.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "D.to(device)\n",
    "G.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary cross-entropy between D's sigmoid output and the real/fake target.\n",
    "criterion = nn.BCELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reset_grad():\n",
    "    \"\"\"Clear the accumulated gradients of both the discriminator and the generator.\"\"\"\n",
    "    for net in (D, G):\n",
    "        net.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both networks are trained with Adam at the same step size.\n",
    "adam_lr = 0.0003\n",
    "d_optimizer = torch.optim.Adam(D.parameters(), lr=adam_lr)\n",
    "g_optimizer = torch.optim.Adam(G.parameters(), lr=adam_lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of mini-batches per epoch; shown as the denominator in the progress log.\n",
    "total_step = len(data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [0/30], Step [200/200], d_loss: 0.1874, g_loss: 7.9065, D(x): 0.97, D(G(z)): 0.13\n"
     ]
    }
   ],
   "source": [
    "# Directory that receives one image grid per epoch.\n",
    "sample_dir = 'samples'\n",
    "os.makedirs(sample_dir, exist_ok=True)\n",
    "\n",
    "\n",
    "def denorm(x):\n",
    "    \"\"\"Map values from [-1, 1] (the range of G's Tanh output) to [0, 1] for saving.\n",
    "\n",
    "    NOTE: real images are only in [-1, 1] when the dataset transform normalizes\n",
    "    them with mean 0.5 / std 0.5; otherwise the saved real grid looks washed out.\n",
    "    \"\"\"\n",
    "    return ((x + 1) / 2).clamp(0, 1)\n",
    "\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, _) in enumerate(data_loader):\n",
    "        # Use the real size of this batch: the last batch of an epoch can be\n",
    "        # smaller than `batch_size`, and reshaping with the nominal size would\n",
    "        # produce rows of the wrong width for D.\n",
    "        cur_batch = images.size(0)\n",
    "        images = images.reshape(cur_batch, -1).to(device)\n",
    "\n",
    "        # BCE targets: 1 marks a real image, 0 marks a generated one.\n",
    "        real_labels = torch.ones(cur_batch, 1, device=device)\n",
    "        fake_labels = torch.zeros(cur_batch, 1, device=device)\n",
    "\n",
    "        # ================================================================== #\n",
    "        #                      Train the discriminator                       #\n",
    "        # ================================================================== #\n",
    "\n",
    "        # Loss on real images: D should output values close to 1.\n",
    "        outputs = D(images)\n",
    "        d_loss_real = criterion(outputs, real_labels)\n",
    "        real_score = outputs\n",
    "\n",
    "        # Loss on generated images: D should output values close to 0.\n",
    "        # detach() keeps this step from back-propagating into G, whose\n",
    "        # parameters are not updated by d_optimizer.\n",
    "        z = torch.randn(cur_batch, latent_size, device=device)\n",
    "        fake_images = G(z).detach()\n",
    "        outputs = D(fake_images)\n",
    "        d_loss_fake = criterion(outputs, fake_labels)\n",
    "        fake_score = outputs\n",
    "\n",
    "        # Total discriminator loss.\n",
    "        d_loss = d_loss_real + d_loss_fake\n",
    "\n",
    "        # Clear stale gradients, back-propagate and update D.\n",
    "        reset_grad()\n",
    "        d_loss.backward()\n",
    "        d_optimizer.step()\n",
    "\n",
    "        # ================================================================== #\n",
    "        #                        Train the generator                         #\n",
    "        # ================================================================== #\n",
    "\n",
    "        # The generator is scored against the *real* labels: it is rewarded\n",
    "        # when D classifies its samples as real.\n",
    "        z = torch.randn(cur_batch, latent_size, device=device)\n",
    "        fake_images = G(z)\n",
    "        outputs = D(fake_images)\n",
    "        g_loss = criterion(outputs, real_labels)\n",
    "\n",
    "        # Clear stale gradients, back-propagate and update G.\n",
    "        reset_grad()\n",
    "        g_loss.backward()\n",
    "        g_optimizer.step()\n",
    "\n",
    "        if (i+1) % 200 == 0:\n",
    "            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' \n",
    "                  .format(epoch+1, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), \n",
    "                          real_score.mean().item(), fake_score.mean().item()))\n",
    "\n",
    "    # Save one grid of real images, once, after the first epoch.\n",
    "    if (epoch+1) == 1:\n",
    "        images = images.reshape(images.size(0), 1, 28, 28)\n",
    "        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))\n",
    "\n",
    "    # Save the last generated batch of this epoch.\n",
    "    fake_images = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)\n",
    "    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))\n",
    "\n",
    "# Persist the trained weights of both networks.\n",
    "torch.save(G.state_dict(), 'G.ckpt')\n",
    "torch.save(D.state_dict(), 'D.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
